{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入相关库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.utils import to_categorical\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense\n",
    "from keras.datasets import cifar10, mnist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载CIFAR10数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((50000, 32, 32, 3), (50000, 10), (10000, 32, 32, 3), (10000, 10))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(X_train, y_train), (X_test, y_test) = cifar10.load_data()\n",
    "X_train = X_train / 255.0\n",
    "X_test = X_test / 255.0\n",
    "y_train, y_test = to_categorical(y_train), to_categorical(y_test)\n",
    "X_train.shape, y_train.shape, X_test.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\Programs\\Anaconda\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:4074: The name tf.nn.avg_pool is deprecated. Please use tf.nn.avg_pool2d instead.\n",
      "\n",
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_1 (Conv2D)            (None, 32, 32, 6)         18        \n",
      "_________________________________________________________________\n",
      "average_pooling2d_1 (Average (None, 16, 16, 6)         0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 16, 16, 16)        96        \n",
      "_________________________________________________________________\n",
      "average_pooling2d_2 (Average (None, 8, 8, 16)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 8, 8, 120)         1920      \n",
      "_________________________________________________________________\n",
      "average_pooling2d_3 (Average (None, 4, 4, 120)         0         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 1920)              0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 120)               230400    \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 84)                10080     \n",
      "_________________________________________________________________\n",
      "dense_3 (Dense)              (None, 10)                840       \n",
      "=================================================================\n",
      "Total params: 243,354\n",
      "Trainable params: 243,354\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Conv2D(6, (1, 1), activation='relu', use_bias=False, input_shape=(32, 32, 3)))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Conv2D(16, (1, 1), activation='relu', use_bias=False))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Conv2D(120, (1, 1), activation='relu', use_bias=False))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(120, activation='relu', use_bias=False))\n",
    "model.add(Dense(84, activation='relu', use_bias=False))\n",
    "model.add(Dense(10, activation='softmax', use_bias=False))\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\Programs\\Anaconda\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n",
      "Train on 50000 samples, validate on 10000 samples\n",
      "Epoch 1/10\n",
      "50000/50000 [==============================] - 234s 5ms/step - loss: 1.8726 - accuracy: 0.3102 - val_loss: 1.6338 - val_accuracy: 0.4065\n",
      "Epoch 2/10\n",
      "50000/50000 [==============================] - 209s 4ms/step - loss: 1.5938 - accuracy: 0.4274 - val_loss: 1.5292 - val_accuracy: 0.4527\n",
      "Epoch 3/10\n",
      "50000/50000 [==============================] - 218s 4ms/step - loss: 1.5125 - accuracy: 0.4584 - val_loss: 1.5179 - val_accuracy: 0.4551\n",
      "Epoch 4/10\n",
      "50000/50000 [==============================] - 230s 5ms/step - loss: 1.4701 - accuracy: 0.4761 - val_loss: 1.4945 - val_accuracy: 0.4675\n",
      "Epoch 5/10\n",
      "50000/50000 [==============================] - 264s 5ms/step - loss: 1.4382 - accuracy: 0.4875 - val_loss: 1.4463 - val_accuracy: 0.4844\n",
      "Epoch 6/10\n",
      "50000/50000 [==============================] - 250s 5ms/step - loss: 1.4127 - accuracy: 0.4943 - val_loss: 1.4266 - val_accuracy: 0.4923\n",
      "Epoch 7/10\n",
      "50000/50000 [==============================] - 697s 14ms/step - loss: 1.3904 - accuracy: 0.5040 - val_loss: 1.4118 - val_accuracy: 0.4997\n",
      "Epoch 8/10\n",
      "50000/50000 [==============================] - 281s 6ms/step - loss: 1.3723 - accuracy: 0.5108 - val_loss: 1.4087 - val_accuracy: 0.4988\n",
      "Epoch 9/10\n",
      "50000/50000 [==============================] - 186s 4ms/step - loss: 1.3542 - accuracy: 0.5187 - val_loss: 1.3940 - val_accuracy: 0.4995\n",
      "Epoch 10/10\n",
      "50000/50000 [==============================] - 190s 4ms/step - loss: 1.3378 - accuracy: 0.5255 - val_loss: 1.3896 - val_accuracy: 0.5108\n"
     ]
    }
   ],
   "source": [
    "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the per-epoch metrics as a tab-separated table (header row first).\n",
    "# The file is opened in 'a+' mode, so re-running this cell appends another table.\n",
    "with open('GDCNN_CIFAR10.txt', 'a+') as out_file:\n",
    "    metrics = history.history\n",
    "    names = list(metrics.keys())\n",
    "    n_epochs = len(metrics['loss'])\n",
    "    # One row per epoch, every value formatted to four decimal places.\n",
    "    rows = [\n",
    "        '\\t'.join('{:.4f}'.format(metrics[name][epoch]) for name in names)\n",
    "        for epoch in range(n_epochs)\n",
    "    ]\n",
    "    lines = ['\\t'.join(names)] + rows\n",
    "    out_file.write(''.join(line + '\\n' for line in lines))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载MNIST数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((60000, 28, 28, 1), (60000, 10), (10000, 28, 28, 1), (10000, 10))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')\n",
    "X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')\n",
    "X_train /= 255.0\n",
    "X_test /= 255.0\n",
    "y_train, y_test = to_categorical(y_train), to_categorical(y_test)\n",
    "X_train.shape, y_train.shape, X_test.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\Programs\\Anaconda\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:4074: The name tf.nn.avg_pool is deprecated. Please use tf.nn.avg_pool2d instead.\n",
      "\n",
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d_1 (Conv2D)            (None, 28, 28, 40)        40        \n",
      "_________________________________________________________________\n",
      "average_pooling2d_1 (Average (None, 14, 14, 40)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_2 (Conv2D)            (None, 14, 14, 40)        1600      \n",
      "_________________________________________________________________\n",
      "average_pooling2d_2 (Average (None, 7, 7, 40)          0         \n",
      "_________________________________________________________________\n",
      "conv2d_3 (Conv2D)            (None, 7, 7, 5)           200       \n",
      "_________________________________________________________________\n",
      "average_pooling2d_3 (Average (None, 3, 3, 5)           0         \n",
      "_________________________________________________________________\n",
      "conv2d_4 (Conv2D)            (None, 3, 3, 1)           5         \n",
      "_________________________________________________________________\n",
      "flatten_1 (Flatten)          (None, 9)                 0         \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 40)                360       \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 10)                400       \n",
      "=================================================================\n",
      "Total params: 2,605\n",
      "Trainable params: 2,605\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model = Sequential()\n",
    "model.add(Conv2D(40, (1, 1), activation='relu', use_bias=False, input_shape=(28, 28, 1)))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Conv2D(40, (1, 1), activation='relu', use_bias=False))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Conv2D(5, (1, 1), activation='relu', use_bias=False))\n",
    "model.add(AveragePooling2D((2, 2)))\n",
    "model.add(Conv2D(1, (1, 1), activation='relu', use_bias=False))\n",
    "model.add(Flatten())\n",
    "model.add(Dense(40, activation='relu', use_bias=False))\n",
    "model.add(Dense(10, activation='softmax', use_bias=False))\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From D:\\Programs\\Anaconda\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n",
      "\n",
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 329s 5ms/step - loss: 1.1671 - accuracy: 0.6082 - val_loss: 0.9559 - val_accuracy: 0.6797\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 335s 6ms/step - loss: 0.9316 - accuracy: 0.6888 - val_loss: 0.9014 - val_accuracy: 0.6958\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 293s 5ms/step - loss: 0.8805 - accuracy: 0.7022 - val_loss: 0.8484 - val_accuracy: 0.7153\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 288s 5ms/step - loss: 0.8468 - accuracy: 0.7150 - val_loss: 0.8295 - val_accuracy: 0.7192\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 280s 5ms/step - loss: 0.8244 - accuracy: 0.7201 - val_loss: 0.8199 - val_accuracy: 0.7141\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 266s 4ms/step - loss: 0.8076 - accuracy: 0.7274 - val_loss: 0.8022 - val_accuracy: 0.7309\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 263s 4ms/step - loss: 0.7936 - accuracy: 0.7306 - val_loss: 0.7976 - val_accuracy: 0.7224\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 265s 4ms/step - loss: 0.7842 - accuracy: 0.7326 - val_loss: 0.7664 - val_accuracy: 0.7401\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 266s 4ms/step - loss: 0.7734 - accuracy: 0.7355 - val_loss: 0.7810 - val_accuracy: 0.7282\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 361s 6ms/step - loss: 0.7677 - accuracy: 0.7371 - val_loss: 0.7708 - val_accuracy: 0.7385\n"
     ]
    }
   ],
   "source": [
    "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 保存结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Write the per-epoch metrics as a tab-separated table (header row first).\n",
    "# The file is opened in 'a+' mode, so re-running this cell appends another table.\n",
    "with open('GDCNN_MNIST.txt', 'a+') as out_file:\n",
    "    metrics = history.history\n",
    "    names = list(metrics.keys())\n",
    "    n_epochs = len(metrics['loss'])\n",
    "    # One row per epoch, every value formatted to four decimal places.\n",
    "    rows = [\n",
    "        '\\t'.join('{:.4f}'.format(metrics[name][epoch]) for name in names)\n",
    "        for epoch in range(n_epochs)\n",
    "    ]\n",
    "    lines = ['\\t'.join(names)] + rows\n",
    "    out_file.write(''.join(line + '\\n' for line in lines))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
